from __future__ import print_function
import matplotlib

matplotlib.use('Agg')

from mpl_toolkits.axes_grid1 import ImageGrid
import matplotlib.pyplot as plt
import numpy as np
import os
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# flags
# Command-line flag registry; FLAGS exposes the parsed values to the script.
flags = tf.app.flags
FLAGS = flags.FLAGS

# When set, the main script sends the input images through its
# color-MNIST preprocessing branch before building the networks.
flags.DEFINE_bool("colored", False, "Use color MNIST")

# hyperparameters
# Number of iterations of the main training loop.
num_steps = 10000
# A grid of generated samples is written to disk every `image_every` steps.
image_every = 100
# Width of the noise vector fed to the generator.
noise_dims = 100
# Number of label classes; passed to classifier().
num_classes = 10
# Examples per batch (also the number of noise vectors sampled per step).
batch_size = 100
# Classifier accuracy and the duality-gap metrics are computed every
# `eval_every` steps; also used to rebuild the x axis of the final plots.
eval_every = 100
# Negative-side slope of the leaky ReLU used in the discriminator.
lrelu_alpha = 0.3
# Metric histories filled in by the training loop, one entry per evaluation:
# duality-gap score, worst-case minmax value, and classifier accuracy.
dg_list=[]
mm=[]
classifier_list=[]


def generator(inputs, scope="generator"):
    """Build the generator network that turns noise vectors into images.

    The noise is projected by a dense layer to a 7x7x64 feature map and
    then passed through two stride-2 transposed convolutions (64 filters
    with ReLU, then 1 filter with tanh). Each transposed convolution is
    followed by `tf.layers.batch_normalization` called with its default
    arguments; note that this includes the tanh output layer, so the
    returned tensor is the batch-normalized tanh activation.

    Args:
        inputs: 2-D float tensor of noise vectors, one row per sample.
        scope: variable scope under which all generator variables are
            created.

    Returns:
        The output tensor of the second (batch-normalized) transposed
        convolution, with a single channel.
    """
    make_init = tf.contrib.layers.xavier_initializer

    def upsample(feature_map, channels, activation):
        # 5x5 transposed convolution with stride 2 and "same" padding,
        # followed by batch normalization.
        expanded = tf.layers.conv2d_transpose(
            inputs=feature_map,
            filters=channels,
            kernel_size=[5, 5],
            strides=2,
            padding="same",
            kernel_initializer=make_init(),
            activation=activation)
        return tf.layers.batch_normalization(expanded)

    with tf.variable_scope(scope):
        projected = tf.layers.dense(
            inputs, kernel_initializer=make_init(), units=7 * 7 * 64)
        # The batch dimension is only known at run time, so read it from
        # the dynamic shape.
        num_samples = tf.shape(projected)[0]
        seed_map = tf.reshape(projected, shape=[num_samples, 7, 7, 64])

        deconv_1 = upsample(seed_map, 64, tf.nn.relu)
        print("deconv_1: {}".format(deconv_1.shape))

        deconv_2 = upsample(deconv_1, 1, tf.nn.tanh)
        print("deconv_2: {}".format(deconv_2.shape))
        return deconv_2


def discriminator(inputs, reuse=False, scope="discriminator"):
    """Build the discriminator network that scores images as real or fake.

    Four identical stages are stacked, with 32, 64, 128 and 128 filters:
    a 5x5 convolution (default stride and padding), a leaky ReLU with
    slope `lrelu_alpha`, and `tf.layers.batch_normalization` called with
    its default arguments. The result is flattened and mapped to a single
    logit by a dense layer.

    Args:
        inputs: 4-D image tensor.
        reuse: passed to `tf.variable_scope`; set it to True to share the
            variables of an earlier call that used the same `scope`.
        scope: variable scope holding the discriminator variables.

    Returns:
        A `(preds, logits)` tuple: the sigmoid of the logit and the raw
        logit, each with one value per input sample.
    """
    stage_filters = (32, 64, 128, 128)

    with tf.variable_scope(scope, reuse=reuse):
        make_init = tf.contrib.layers.xavier_initializer

        hidden = inputs
        for num_filters in stage_filters:
            conv = tf.layers.conv2d(
                inputs=hidden,
                filters=num_filters,
                kernel_size=[5, 5],
                kernel_initializer=make_init())
            # Leaky ReLU written as max(x, alpha * x).
            activated = tf.maximum(conv, lrelu_alpha * conv)
            hidden = tf.layers.batch_normalization(activated)

        print("layer 4: {}".format(hidden.shape))
        features = tf.layers.flatten(hidden)
        logits = tf.layers.dense(inputs=features, units=1)
        return tf.sigmoid(logits), logits


def classifier(inputs, labels, num_classes):
    """Build a small convolutional classifier with its loss and accuracy.

    Three stages of 3x3 convolution (32, 64 and 128 filters, ReLU) each
    followed by 2x2 max pooling with stride 2 feed a dense layer that
    produces one logit per class. All variables live in the "classifier"
    variable scope.

    Args:
        inputs: 4-D image tensor.
        labels: one-hot label tensor whose second dimension indexes the
            classes (it is arg-maxed over axis 1 and used as the target
            of the softmax cross-entropy).
        num_classes: number of output classes; sets the width of the
            final dense layer.

    Returns:
        A `(softmax_op, cross_entropy, accuracy)` tuple: per-sample class
        probabilities, the mean softmax cross-entropy over the batch, and
        the fraction of samples whose predicted class matches the label.
    """
    with tf.variable_scope("classifier"):
        xav_init = tf.contrib.layers.xavier_initializer

        hidden = inputs
        for num_filters in (32, 64, 128):
            hidden = tf.layers.conv2d(
                inputs=hidden,
                filters=num_filters,
                kernel_size=[3, 3],
                activation=tf.nn.relu,
                kernel_initializer=xav_init())
            hidden = tf.layers.max_pooling2d(
                hidden, pool_size=[2, 2], strides=[2, 2])

        flattened = tf.layers.flatten(hidden)
        # Size the output layer from the argument instead of a literal 10,
        # so the `num_classes` parameter actually takes effect.
        logits = tf.layers.dense(inputs=flattened, units=num_classes)
        softmax_op = tf.nn.softmax(logits)

        correct_predictions = tf.equal(
            tf.argmax(softmax_op, 1), tf.argmax(labels, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_predictions, tf.float32))

        cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
            labels=labels, logits=logits)
        cross_entropy = tf.reduce_mean(cross_entropy, name="c_cross_entropy")
        return softmax_op, cross_entropy, accuracy


if __name__ == "__main__":
    # Script entry point. Builds the classifier and the GAN, plus temporary
    # "worst case" copies of the discriminator and generator, then trains
    # them. Every `eval_every` steps it records the classifier accuracy,
    # the worst-case minmax value and the duality-gap score in the
    # module-level lists `classifier_list`, `mm` and `dg_list`.

    def _to_signed_range(pixels):
        """Rescale pixel values with (x - 0.5) / 0.5, i.e. [0, 1] -> [-1, 1]."""
        return (pixels - 0.5) / 0.5

    def _next_feed(dataset):
        """Return a feed dict with one batch from `dataset` and fresh noise.

        Images get the same rescaling as the training batches, so that
        the networks are evaluated on the input range they were trained on.
        """
        batch_x, batch_y = dataset.next_batch(batch_size)
        return {
            images: _to_signed_range(batch_x),
            labels: batch_y,
            noise: np.random.randn(batch_size, noise_dims),
        }

    def _model_weights(tag):
        """Return the global variables whose name contains `tag`, minus
        optimizer slot variables.

        Optimizer slots (e.g. ".../kernel/Adam") carry the name of the
        variable they belong to, so they match `tag` as well and have to
        be filtered out explicitly.
        """
        return [var for var in tf.global_variables()
                if tag in var.name
                and 'Adam' not in var.name
                and 'RMSProp' not in var.name]

    # ---- graph inputs ----
    noise = tf.placeholder(
        shape=[None, noise_dims], dtype=tf.float32, name="noise")
    images = tf.placeholder(
        shape=[None, 28, 28, 1], dtype=tf.float32, name="images")
    labels = tf.placeholder(shape=[None, 10], dtype=tf.int32, name="labels")

    if FLAGS.colored:
        # NOTE(review): this branch looks non-functional as written and
        # needs confirming before use: the batch dimension of `images` is
        # unknown at graph-construction time, the generator emits a single
        # channel while this produces three, and the batches fed below are
        # rescaled with (x - 0.5) / 0.5 rather than being 0-255.
        images = tf.image.grayscale_to_rgb(images)
        # images is 0-255 --> -1-1
        zero_one_mask = (tf.to_float(images) / 255.)
        colors = tf.random_normal([images.shape[0], 1, 1, 3], mean=0, stddev=.5)

        scaled_images = (zero_one_mask * 1.0 - 1.0) + zero_one_mask * colors
        scaled_images += tf.random_normal(scaled_images.shape, stddev=0.2)
        scaled_images = tf.clip_by_value(scaled_images, -1, 1)
        images = scaled_images

    # ---- classifier ----
    c_preds, c_cross_entropy, c_accuracy = classifier(images, labels,
                                                      num_classes)
    c_opt = tf.train.AdamOptimizer().minimize(c_cross_entropy)

    # ---- GAN ----
    gen_op = generator(noise)

    d_real_preds, d_real_logits = discriminator(images)
    d_fake_preds, d_fake_logits = discriminator(gen_op, reuse=True)

    # Discriminator loss: real images labelled 1, generated images labelled 0.
    d_fake_loss = tf.nn.sigmoid_cross_entropy_with_logits(
        logits=d_fake_logits, labels=tf.zeros_like(d_fake_logits))
    d_fake_loss = tf.reduce_mean(d_fake_loss)

    d_real_loss = tf.nn.sigmoid_cross_entropy_with_logits(
        logits=d_real_logits, labels=tf.ones_like(d_real_logits))
    d_real_loss = tf.reduce_mean(d_real_loss)

    d_loss = d_fake_loss + d_real_loss

    # Generator loss: generated images should be classified as real.
    g_loss = tf.nn.sigmoid_cross_entropy_with_logits(
        logits=d_fake_logits, labels=tf.ones_like(d_fake_logits))
    g_loss = tf.reduce_mean(g_loss)

    d_adam = tf.train.AdamOptimizer()
    g_adam = tf.train.AdamOptimizer()

    d_vars = tf.get_collection(
        tf.GraphKeys.GLOBAL_VARIABLES, scope="discriminator")
    g_vars = tf.get_collection(
        tf.GraphKeys.GLOBAL_VARIABLES, scope="generator")

    d_opt = d_adam.minimize(d_loss, var_list=d_vars)
    g_opt = g_adam.minimize(g_loss, var_list=g_vars)

    # ---- duality gap: temporary discriminator trained against the
    # current generator ----
    with tf.variable_scope('worst_calc', reuse=tf.AUTO_REUSE):
        new_opt = tf.train.AdamOptimizer(learning_rate=1e-3)
        d_real_preds_tmp, d_real_logits_tmp = discriminator(
            images, scope="discTMP")
        d_fake_preds_tmp, d_fake_logits_tmp = discriminator(
            gen_op, reuse=True, scope="discTMP")
        disc_loss_worst = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=d_real_logits_tmp,
                labels=tf.ones_like(d_real_logits_tmp))) + tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=d_fake_logits_tmp,
                labels=tf.zeros_like(d_fake_logits_tmp)))
        # Collected before minimize() creates the optimizer slots, so this
        # list only holds the temporary discriminator's own variables.
        d_vars_worst = [var for var in tf.global_variables()
                        if 'discTMP' in var.name]
        find_worst_d = new_opt.minimize(disc_loss_worst, var_list=d_vars_worst)

    # ---- duality gap: temporary generator trained against the current
    # discriminator ----
    with tf.variable_scope('worst_calc_gen', reuse=tf.AUTO_REUSE):
        new_opt_gen = tf.train.AdamOptimizer(learning_rate=1e-3)
        x_w = generator(noise, "genTMP")

    d_real_preds_worst, d_real_logits_worst = discriminator(images, reuse=True)
    d_fake_preds_worst, d_fake_logits_worst = discriminator(x_w, reuse=True)

    with tf.variable_scope('worst_calc_gen', reuse=tf.AUTO_REUSE):
        gen_loss_worst = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=d_fake_logits_worst,
                labels=tf.ones_like(d_fake_logits_worst)))
        g_vars_worst = [var for var in tf.global_variables()
                        if 'genTMP' in var.name]
        find_worst_g = new_opt_gen.minimize(
            gen_loss_worst, var_list=g_vars_worst)

    # Re-initializer for everything under the 'worst_calc*' scopes: the
    # temporary networks and the state of their optimizers.
    d_init = [var for var in tf.global_variables()
              if 'worst_calc' in var.name]
    init_new_vars_op = tf.variables_initializer(d_init)

    # Ops copying the current D/G weights into the temporary D/G. Optimizer
    # slots are excluded so the temporary optimizers restart from the clean
    # state produced by `init_new_vars_op` instead of inheriting the Adam
    # moments of the main optimizers.
    d_vars_tmp = _model_weights('discTMP')
    d_vars_0 = _model_weights('discriminator/')
    g_vars_tmp = _model_weights('genTMP')
    g_vars_0 = _model_weights('generator/')
    if len(d_vars_tmp) != len(d_vars_0) or len(g_vars_tmp) != len(g_vars_0):
        # The copy pairs variables by position; a length mismatch would
        # silently pair the wrong tensors.
        raise ValueError(
            "temporary and current networks have different variable counts")

    curr_to_tmp = []
    for tmp_var, current_var in zip(d_vars_tmp, d_vars_0):
        print(tmp_var)
        curr_to_tmp.append(tmp_var.assign(current_var))
    for tmp_var, current_var in zip(g_vars_tmp, g_vars_0):
        curr_to_tmp.append(tmp_var.assign(current_var))

    current_to_tmp = tf.group(*curr_to_tmp)

    mnist = input_data.read_data_sets(
        "MNIST_data", one_hot=True, reshape=False)

    # Negated discriminator objective evaluated from fed-in logits:
    # `plc_float_r` takes logits on real images, `plc_float` on fakes.
    plc_float = tf.placeholder(tf.float32)
    plc_float_r = tf.placeholder(tf.float32)
    disc_loss_calc2 = -tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(
            logits=plc_float_r, labels=tf.ones_like(plc_float_r)) +
        tf.nn.sigmoid_cross_entropy_with_logits(
            logits=plc_float, labels=tf.zeros_like(plc_float)))

    # savefig() does not create missing directories.
    image_dir = "imagesCollapse"
    if not os.path.isdir(image_dir):
        os.makedirs(image_dir)

    # Optimization steps spent on each temporary network per evaluation.
    worst_case_steps = 500

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for i in range(num_steps):

            batch_images, batch_labels = mnist.train.next_batch(batch_size)
            batch_images = _to_signed_range(batch_images)
            batch_noise = np.random.randn(batch_size, noise_dims)
            sess.run(c_opt,
                     feed_dict={images: batch_images, labels: batch_labels})
            sess.run(d_opt,
                     feed_dict={images: batch_images, labels: batch_labels,
                                noise: batch_noise})
            sess.run(g_opt, feed_dict={noise: batch_noise})

            if i % image_every == 0:
                # Save a 10x10 grid of generated samples (the reshape and
                # the grid both assume 100 samples, i.e. batch_size == 100).
                gen_noise = np.random.randn(batch_size, noise_dims)
                gen_images = sess.run(gen_op, feed_dict={noise: gen_noise})
                gen_images = gen_images.reshape(100, 28, 28)
                fig = plt.figure()
                grid = ImageGrid(fig, 111, nrows_ncols=(10, 10), axes_pad=0)
                for x in range(100):
                    grid[x].imshow(gen_images[x], cmap="gray")
                    grid[x].set_xticks([])
                    grid[x].set_yticks([])
                plt.savefig(
                    os.path.join(image_dir, "step_{}.png".format(i)))
                plt.close()

            if i % eval_every == 0:
                val_images, val_labels = mnist.validation.next_batch(
                    batch_size)
                # Same rescaling as the training batches; feeding the raw
                # pixels would evaluate the classifier off-distribution.
                accuracy = sess.run(
                    c_accuracy,
                    feed_dict={images: _to_signed_range(val_images),
                               labels: val_labels})
                print("Epoch: {}, accuracy: {}".format(i, accuracy))

                # First re-initialize the temporary D_tmp/G_tmp and their
                # optimizer state...
                sess.run(init_new_vars_op)
                # ...then overwrite their weights with those of the
                # current D/G.
                sess.run(current_to_tmp)

                # For fixed G, find the worst D_tmp.
                for _ in range(worst_case_steps):
                    sess.run(find_worst_d,
                             feed_dict=_next_feed(mnist.validation))
                # Calculate the worst minmax on a test batch.
                df_final, dr_final = sess.run(
                    [d_fake_logits_tmp, d_real_logits_tmp],
                    feed_dict=_next_feed(mnist.test))
                worst_minmax = sess.run(
                    disc_loss_calc2,
                    feed_dict={plc_float: df_final, plc_float_r: dr_final})

                # For fixed D, find the worst G_tmp.
                for _ in range(worst_case_steps):
                    sess.run(find_worst_g,
                             feed_dict=_next_feed(mnist.validation))
                # Calculate the worst maxmin on a test batch.
                df_final, dr_final = sess.run(
                    [d_fake_logits_worst, d_real_logits_worst],
                    feed_dict=_next_feed(mnist.test))
                worst_maxmin = sess.run(
                    disc_loss_calc2,
                    feed_dict={plc_float: df_final, plc_float_r: dr_final})

                # Report the metrics.
                dualitygap_score = worst_minmax - worst_maxmin
                print('The duality gap score is: ')
                print('{0:.16f}'.format(dualitygap_score))
                dg_list.append(dualitygap_score)
                print('The minmax is: ')
                print('{0:.16f}'.format(worst_minmax))
                mm.append(worst_minmax)
                classifier_list.append(accuracy)

                # TODO: log to tensorboard and save the trained model
                # periodically.
# Plot each recorded metric against the training step at which it was
# measured, and save every curve as both EPS and PNG.
steps = [eval_every * index for index in range(len(dg_list))]
_curves = (
    (dg_list, 'duality gap', 'DG_collapse.eps', 'DG_collapse.png'),
    (mm, 'minimax', 'MM_collapse.eps', 'MM_collapse.png'),
    (classifier_list, 'accuracy', 'C_collapse.eps', 'C_collapse.png'),
)
for _position, (_series, _ylabel, _eps_path, _png_path) in enumerate(_curves):
    if _position > 0:
        # Wipe the previous curve before drawing the next one.
        plt.clf()
    plt.plot(steps, _series)
    plt.ticklabel_format(style='sci', axis='x', scilimits=(0, 0))
    plt.ylabel(_ylabel)
    plt.xlabel('steps')
    plt.savefig(_eps_path, format='eps', dpi=1000)
    plt.savefig(_png_path, format='png')